#include "cell.h"
#include <qthread.h>
// Argument bundle handed from the host to the slave-side f.r (virial) kernel.
// Both members are pointers that the slave code fetches from with pe_get()
// and, for `stat`, writes back to with pe_put().
typedef struct fdotr_param {
  cellgrid_t *grid;  // cell grid whose per-cell positions and forces are read
  mdstat_t *stat;    // statistics record whose virial[0..5] is accumulated into
} fdotr_param_t;
#ifdef __sw_host__
// Slave-side entry points that the host wrappers below pass to qthread_spawn().
// NOTE(review): presumably these resolve to force_fdotr_compute_cpe() and
// force_reset_cpe() from the __sw_slave__ section, with the `slave_` prefix
// added by the toolchain -- confirm against the build setup.
extern void slave_force_fdotr_compute_cpe(fdotr_param_t*);
extern void slave_force_reset_cpe(cellgrid_t*);
// Host wrapper: reset the per-atom forces of `grid` by running the reset
// kernel through qthread.
//
// This replaces an earlier serial host loop that, for every cell of the grid,
// set cell->f[i] to zero for each i in [0, cell->natom).
void force_reset_sw(cellgrid_t *grid) {
  // Launch the slave kernel with the grid pointer as its only argument.
  qthread_spawn(slave_force_reset_cpe, grid);
  // Presumably blocks until the spawned kernel has finished -- confirm against
  // the qthread API.
  qthread_join();
}
// Host wrapper: run the f.r virial kernel over `grid`, accumulating the
// result into `stat`.
//
// grid: cell grid forwarded to the slave kernel.
// stat: statistics record forwarded to the slave kernel.
void force_fdotr_compute_sw(cellgrid_t *grid, mdstat_t *stat){
  // The argument bundle lives on this stack frame; it stays valid for the
  // kernel because qthread_join() is called before the function returns.
  fdotr_param_t args = { grid, stat };
  qthread_spawn(slave_force_fdotr_compute_cpe, &args);
  qthread_join();
}
#endif
#ifdef __sw_slave__
extern "C"{
#include <qthread_slave.h>
}
#include "dma_macros_new.h"
#include "reg_reduce.h"
#include "swarch.h"
// Slave kernel: accumulate the f.r virial over the cells visited by this
// core and add the core-wide total to pm->stat->virial.
//
// pm: host-side argument bundle; pm->grid is fetched into a local copy and
//     pm->stat is fetched (and later written back) by the core with
//     _MYID == 0 only.
//
// virial[0..5] receive, in order, the xx, yy, zz, xy, xz and yz products of
// the absolute atom position (cell basis + stored offset) and the atom force.
void force_fdotr_compute_cpe(fdotr_param_t *pm){
  dma_init();
  cellgrid_t lgrid;
  mdstat_t lstat;
  pe_get(pm->grid, &lgrid, sizeof(cellgrid_t));
  if (_MYID == 0) {
    // Core 0 starts from the current host-side record so that the values
    // already stored there are kept and the new contribution is added on top.
    pe_get(pm->stat, &lstat, sizeof(mdstat_t));
  } else {
    // Every other core starts its partial sums from zero.
    for (int i = 0; i < 6; i ++) {
      lstat.virial[i] = 0;
    }
  }
  dma_syn();

  // NOTE(review): presumably this macro hands out cells to the cores in a
  // round-robin fashion -- confirm against its definition.
  FOREACH_CELL_CPE_RR(&lgrid, ii, jj, kk, cell){
    // Fetch the cell metadata first: natom determines how much to transfer.
    cellmeta_t imeta;
    pe_get(&cell->basis, &imeta, sizeof(cellmeta_t));
    dma_syn();
    vec<real> x[CELL_CAP], f[CELL_CAP];
    pe_get(cell->x, x, sizeof(vec<real>) * imeta.natom);
    pe_get(cell->f, f, sizeof(vec<real>) * imeta.natom);
    dma_syn();
    for (int i = 0; i < imeta.natom; i ++){
      // Position used for the virial is the cell basis plus the stored x[i].
      vec<real> xi;
      vecaddv(xi, imeta.basis, x[i]);
      lstat.virial[0] += xi.x * f[i].x;
      lstat.virial[1] += xi.y * f[i].y;
      lstat.virial[2] += xi.z * f[i].z;
      lstat.virial[3] += xi.x * f[i].y;
      lstat.virial[4] += xi.x * f[i].z;
      lstat.virial[5] += xi.y * f[i].z;
    }
  }

  // The reduction is handed two 4-wide vectors (8 doubles) while only six
  // virial components exist. The two padding lanes are zeroed so that the
  // reduction never reads indeterminate values.
  // NOTE(review): if reg_reduce_inplace_doublev4 uses aligned vector
  // accesses, dvir may also need an explicit 32-byte alignment -- confirm.
  double dvir[8];
  for (int i = 0; i < 6; i ++) {
    dvir[i] = lstat.virial[i];
  }
  dvir[6] = 0;
  dvir[7] = 0;
  reg_reduce_inplace_doublev4(dvir, 2);

  // Only core 0 holds a full copy of the host record, so only it writes back.
  if (_MYID == 0){
    for (int i = 0; i < 6; i ++) {
      lstat.virial[i] = dvir[i];
    }
    pe_put(pm->stat, &lstat, sizeof(mdstat_t));
    dma_syn();
  }
}
// Slave kernel: overwrite the force array of every cell visited by this core
// with zeros.
//
// grid: host-side cell grid; a local copy is fetched before iterating.
void force_reset_cpe(cellgrid_t *grid) {
  dma_init();

  // Build one local all-zero buffer and reuse it as the source for every cell.
  vec<real> zeros[CELL_CAP];
  for (int slot = 0; slot != CELL_CAP; ++slot) {
    vecset1(zeros[slot], 0);
  }

  cellgrid_t local_grid;
  pe_get(grid, &local_grid, sizeof(cellgrid_t));
  dma_syn();

  FOREACH_CELL_CPE_RR(&local_grid, ii, jj, kk, cell){
    // The metadata is needed only for natom, which sizes the write-back.
    cellmeta_t meta;
    pe_get(&cell->basis, &meta, sizeof(cellmeta_t));
    dma_syn();

    pe_put(cell->f, zeros, sizeof(vec<real>) * meta.natom);
    dma_syn();
  }
}
#endif